{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Hidden Trigger Backdoor Attack Example\n",
    "This notebook shows how to use the ART implementation of the Hidden Trigger Backdoor attack to poison a model. The goal of the attack is to create poisoned inputs within a target class with no visible trigger that can be used during model finetuning to poison the model. After finetuning, adding a trigger to certain inputs not in the target class will cause the model to reliably misclassify those inputs as the target class. Full details of the attack can be found at https://arxiv.org/abs/1910.00033."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Temporary directory not created yet\n",
      "Temporary directory: /tmp/tmp9vc5phno\n"
     ]
    }
   ],
   "source": [
    "import os, sys\n",
    "from os.path import abspath\n",
    "\n",
    "module_path = os.path.abspath(os.path.join('..'))\n",
    "if module_path not in sys.path:\n",
    "    sys.path.append(module_path)\n",
    "\n",
    "import tempfile\n",
    "# Create a temporary directory for the model checkpoint. Remove the existing one if this cell is rerun\n",
    "try:\n",
    "    temp_model_dir.cleanup()\n",
    "except NameError:\n",
    "    print(\"Temporary directory not created yet\")\n",
    "finally:\n",
    "    temp_model_dir = tempfile.TemporaryDirectory() \n",
    "    print(\"Temporary directory:\", temp_model_dir.name)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load Dataset\n",
    "We will load the CIFAR10 dataset as well as define the mean and std as done in the original paper."
   ]
  },
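  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a rough illustration of what the `mean` and `std` values are for, the sketch below assumes they are applied per channel as `(x - mean) / std` on NCHW inputs. The helper name `normalize_nchw` is hypothetical and is not part of ART; it only shows the arithmetic.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "mean = (0.4914, 0.4822, 0.4465)\n",
    "std = (0.2023, 0.1994, 0.201)\n",
    "\n",
    "def normalize_nchw(x, mean, std):\n",
    "    # Hypothetical helper: reshape to (1, C, 1, 1) so the values broadcast over N, H and W\n",
    "    m = np.asarray(mean, dtype=np.float32).reshape(1, -1, 1, 1)\n",
    "    s = np.asarray(std, dtype=np.float32).reshape(1, -1, 1, 1)\n",
    "    return (x - m) / s\n",
    "\n",
    "# Dummy batch of two mid-gray images in NCHW layout\n",
    "x = np.full((2, 3, 32, 32), 0.5, dtype=np.float32)\n",
    "x_norm = normalize_nchw(x, mean, std)\n",
    "print(x_norm.shape)\n",
    "```\n",
    "\n",
    "In this notebook the same pair is passed to the estimator through its `preprocessing` argument, so the raw images are left unnormalized."
   ]
  },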
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from art.utils import load_dataset\n",
    "(x_train, y_train), (x_test, y_test), min_, max_ = load_dataset('cifar10')\n",
    "# Step 1a: Swap axes to PyTorch's NCHW format\n",
    "\n",
    "x_train = np.transpose(x_train, (0, 3, 1, 2)).astype(np.float32)\n",
    "x_test = np.transpose(x_test, (0, 3, 1, 2)).astype(np.float32)\n",
    "mean = (0.4914, 0.4822, 0.4465) \n",
    "std = (0.2023, 0.1994, 0.201)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Define and Train Model\n",
    "We will use the modified AlexNet model from the original paper as the model to poison. We temporarily save a model checkpoint to disk. This checkpoint is used during the finetuning step to reset the model to its original state before the attack, because some model state appears to change after the attack is performed."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "from art.estimators.classification import PyTorchClassifier\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "\n",
    "num_classes=10\n",
    "feature_size=4096\n",
    "model=nn.Sequential(\n",
    "        nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1),\n",
    "        nn.ReLU(inplace=True),\n",
    "        nn.MaxPool2d(kernel_size=3, stride=2),\n",
    "        nn.Conv2d(64, 192, kernel_size=3, stride=1, padding=1),\n",
    "        nn.ReLU(inplace=True),\n",
    "        nn.MaxPool2d(kernel_size=3, stride=2),\n",
    "        nn.Conv2d(192, 384, kernel_size=3, stride=1, padding=1),\n",
    "        nn.ReLU(inplace=True),\n",
    "        nn.Conv2d(384, 256, kernel_size=3, stride=1, padding=1),\n",
    "        nn.ReLU(inplace=True),\n",
    "        nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1),\n",
    "        nn.ReLU(inplace=True),\n",
    "        nn.MaxPool2d(kernel_size=3, stride=2),\n",
    "        nn.Flatten(),\n",
    "        nn.Dropout(),\n",
    "        nn.Linear(256 * 1 * 1, 4096),\n",
    "        nn.ReLU(inplace=True),\n",
    "        nn.Dropout(),\n",
    "        nn.Linear(4096, feature_size),\n",
    "        nn.ReLU(inplace=True),\n",
    "        nn.Linear(feature_size, num_classes)\n",
    ")\n",
    "\n",
    "# Define the ART Estimator\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=2e-4)\n",
    "classifier = PyTorchClassifier(\n",
    "    model=model,\n",
    "    clip_values=(min_, max_),\n",
    "    loss=criterion,\n",
    "    optimizer=optimizer,\n",
    "    input_shape=(3, 32, 32),\n",
    "    nb_classes=10,\n",
    "    preprocessing=(mean, std)\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.01\n",
      "0.001\n"
     ]
    }
   ],
   "source": [
    "# Train the model\n",
    "\n",
    "classifier.fit(x_train, y_train, nb_epochs=100, batch_size=128, verbose=True)\n",
    "for param_group in classifier.optimizer.param_groups:\n",
    "    print(param_group[\"lr\"])\n",
    "    param_group[\"lr\"] *= 0.1\n",
    "classifier.fit(x_train, y_train, nb_epochs=50, batch_size=128, verbose=True)\n",
    "for param_group in classifier.optimizer.param_groups:\n",
    "    print(param_group[\"lr\"])\n",
    "    param_group[\"lr\"] *= 0.1\n",
    "classifier.fit(x_train, y_train, nb_epochs=50, batch_size=128, verbose=True)\n",
    "torch.save(model.state_dict(), temp_model_dir.name + \"/htbd_model.pth\") # Write the checkpoint to a temporary directory\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy on benign test examples: 77.19%\n"
     ]
    }
   ],
   "source": [
    "predictions = classifier.predict(x_test)\n",
    "accuracy = np.sum(np.argmax(predictions, axis=1) == np.argmax(y_test, axis=1)) / len(y_test)\n",
    "print(\"Accuracy on benign test examples: {}%\".format(accuracy * 100))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Generate the Poison\n",
    "We will now generate the poison using the hidden trigger backdoor attack. First, we define the target and source classes as well as the backdoor trigger. The target class is the class we want to insert poisoned data into. The source class is the class whose images we will add a trigger to in order to cause misclassification into the target class.\n",
    "\n",
    "The backdoor trigger will be a small image patch inserted into the source class images. At test time, we should be able to use this trigger to cause the classifier to misclassify source class images with the trigger added as the target class."
   ]
  },
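  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Conceptually, inserting a patch trigger amounts to overwriting a small square region of each image. The sketch below is a minimal NumPy illustration under that assumption; `paste_patch` is a hypothetical helper, the all-ones patch is a stand-in for the real trigger image, and the axis convention for the shifts is assumed rather than taken from ART.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def paste_patch(x, patch, x_shift, y_shift):\n",
    "    # Hypothetical helper: overwrite a region of every NCHW image with the patch\n",
    "    x = x.copy()\n",
    "    h, w = patch.shape[-2:]\n",
    "    x[:, :, y_shift:y_shift + h, x_shift:x_shift + w] = patch\n",
    "    return x\n",
    "\n",
    "patch_size = 8\n",
    "shift = 32 - patch_size - 5  # top-left corner of the patch\n",
    "images = np.zeros((4, 3, 32, 32), dtype=np.float32)\n",
    "patch = np.ones((3, patch_size, patch_size), dtype=np.float32)  # stand-in trigger\n",
    "triggered = paste_patch(images, patch, shift, shift)\n",
    "print(triggered[0, 0, shift, shift], triggered[0, 0, shift - 1, shift])\n",
    "```\n",
    "\n",
    "The input array is copied first, so the clean images stay available for comparison against the triggered ones."
   ]
  },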
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "from art.attacks.poisoning.backdoor_attack import PoisoningAttackBackdoor\n",
    "target = np.array([0,0,0,0,1,0,0,0,0,0])\n",
    "source = np.array([0,0,0,1,0,0,0,0,0,0])\n",
    "\n",
    "# Backdoor Trigger Parameters\n",
    "patch_size = 8\n",
    "x_shift = 32 - patch_size - 5\n",
    "y_shift = 32 - patch_size - 5\n",
    "\n",
    "# Define the backdoor poisoning object. Calling backdoor.poison(x) will insert the trigger into x.\n",
    "from art.attacks.poisoning import perturbations\n",
    "def mod(x):\n",
    "    original_dtype = x.dtype\n",
    "    x = perturbations.insert_image(x, backdoor_path=\"../../utils/data/backdoors/htbd.png\",\n",
    "                                   channels_first=True, random=False, x_shift=x_shift, y_shift=y_shift,\n",
    "                                   size=(patch_size,patch_size), mode='RGB', blend=1)\n",
    "    return x.astype(original_dtype)\n",
    "backdoor = PoisoningAttackBackdoor(mod)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Here we run the attack. `eps` controls how much the target images can be perturbed, measured in l-infinity distance. `feature_layer` dictates which layer's output is used to define the attack's loss. It can be either the name of the layer or the layer index according to the ART estimator. `poison_percent` controls how many poisoned samples are generated, based on the size of the input data.\n",
    "\n",
    "The attack returns poisoned inputs of the target class and the indices in the data that those poisoned inputs should replace."
   ]
  },
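  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the role of `eps` concrete, the sketch below shows one common way an l-infinity bound is enforced: clip each perturbed pixel to within `eps` of the original, then to the valid pixel range. This is an illustrative assumption about the constraint, not a copy of ART's implementation, and `project_linf` is a hypothetical helper.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def project_linf(x_adv, x_orig, eps, clip_min=0.0, clip_max=1.0):\n",
    "    # Keep every pixel within eps of the original, then within the valid pixel range\n",
    "    x_adv = np.clip(x_adv, x_orig - eps, x_orig + eps)\n",
    "    return np.clip(x_adv, clip_min, clip_max)\n",
    "\n",
    "eps = 16 / 255\n",
    "rng = np.random.default_rng(0)\n",
    "x_orig = rng.random((2, 3, 32, 32)).astype(np.float32)\n",
    "# Random noise stands in for an unconstrained update to the poison images\n",
    "x_adv = x_orig + rng.normal(scale=0.2, size=x_orig.shape).astype(np.float32)\n",
    "x_proj = project_linf(x_adv, x_orig, eps)\n",
    "print(np.abs(x_proj - x_orig).max() <= eps + 1e-6)\n",
    "```\n",
    "\n",
    "With `eps=16/255`, no pixel of a poisoned image differs from the original by more than about 6% of the full intensity range."
   ]
  },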
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "fee71bb59beb4d6faf89b395d458a4c0",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Hidden Trigger:   0%|          | 0/2 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch:  0 | batch: 0 | i:     0 | LR: 0.00100 |                         Loss Val: 3277.405 | Loss Avg: 3277.405\n",
      "Epoch:  0 | batch: 0 | i:   100 | LR: 0.00100 |                         Loss Val: 13.257 | Loss Avg: 128.395\n",
      "Max_Loss: 9.954315185546875\n",
      "Epoch:  0 | batch: 1 | i:     0 | LR: 0.00100 |                         Loss Val: 3344.732 | Loss Avg: 124.482\n",
      "Max_Loss: 9.848298072814941\n",
      "Number of poison samples generated: 50\n"
     ]
    }
   ],
   "source": [
    "from art.attacks.poisoning import HiddenTriggerBackdoor\n",
    "poison_attack = HiddenTriggerBackdoor(classifier, eps=16/255, target=target, source=source, feature_layer=19, backdoor=backdoor, decay_coeff = .95, decay_iter = 2000, max_iter=5000, batch_size=25, poison_percent=.01)\n",
    "\n",
    "poison_data, poison_indices = poison_attack.poison(x_train, y_train)\n",
    "print(\"Number of poison samples generated:\", len(poison_data))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Finetune the Model\n",
    "Now, we must finetune the model using the poisoned data and a small number of clean training inputs. Here, we randomly select an equal number of training inputs from each of the classes."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create finetuning dataset\n",
    "dataset_size = 2500\n",
    "num_classes = 10\n",
    "num_per_class = dataset_size/num_classes\n",
    "\n",
    "poison_dataset_inds = []\n",
    "\n",
    "for i in range(num_classes):\n",
    "    class_inds = np.where(np.argmax(y_train,axis=1) == i)[0]\n",
    "    num_select = int(num_per_class)\n",
    "    if np.argmax(target) == i:\n",
    "        num_select = int(num_select - len(poison_data))\n",
    "        poison_dataset_inds.append(poison_indices)\n",
    "    poison_dataset_inds.append(np.random.choice(class_inds, num_select, replace=False))\n",
    "    \n",
    "poison_dataset_inds = np.concatenate(poison_dataset_inds)\n",
    "\n",
    "poison_x = np.copy(x_train)\n",
    "poison_x[poison_indices] = poison_data\n",
    "poison_x = poison_x[poison_dataset_inds]\n",
    "\n",
    "poison_y = np.copy(y_train)[poison_dataset_inds]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We reload the model from the temporary checkpoint created earlier to reset it to its original trained state, in case something changed during the attack. We freeze all of the layers and then replace the final layer with a randomly initialized layer."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch.nn as nn\n",
    "\n",
    "# Load model again\n",
    "num_classes=10\n",
    "feature_size=4096\n",
    "model=nn.Sequential(\n",
    "        nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1),\n",
    "        nn.ReLU(inplace=True),\n",
    "        nn.MaxPool2d(kernel_size=3, stride=2),\n",
    "        nn.Conv2d(64, 192, kernel_size=3, stride=1, padding=1),\n",
    "        nn.ReLU(inplace=True),\n",
    "        nn.MaxPool2d(kernel_size=3, stride=2),\n",
    "        nn.Conv2d(192, 384, kernel_size=3, stride=1, padding=1),\n",
    "        nn.ReLU(inplace=True),\n",
    "        nn.Conv2d(384, 256, kernel_size=3, stride=1, padding=1),\n",
    "        nn.ReLU(inplace=True),\n",
    "        nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1),\n",
    "        nn.ReLU(inplace=True),\n",
    "        nn.MaxPool2d(kernel_size=3, stride=2),\n",
    "        nn.Flatten(),\n",
    "        nn.Dropout(),\n",
    "        nn.Linear(256 * 1 * 1, 4096),\n",
    "        nn.ReLU(inplace=True),\n",
    "        nn.Dropout(),\n",
    "        nn.Linear(4096, feature_size),\n",
    "        nn.ReLU(inplace=True),\n",
    "        nn.Linear(feature_size, num_classes)\n",
    ")\n",
    "model.load_state_dict(torch.load(temp_model_dir.name+\"/htbd_model.pth\"))\n",
    "temp_model_dir.cleanup() # Remove the temporary directory after loading the checkpoint\n",
    "\n",
    "# Freeze all existing parameters; the final layer is replaced below with a new trainable one\n",
    "for i, param in enumerate(model.parameters()):\n",
    "    param.requires_grad = False\n",
    "\n",
    "\n",
    "num_classes=10\n",
    "feature_size=4096\n",
    "model[20] = nn.Linear(feature_size, num_classes)\n",
    "\n",
    "\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "optimizer = optim.SGD(model.parameters(), lr=0.5, momentum=0.9, weight_decay=2e-4)\n",
    "\n",
    "classifier = PyTorchClassifier(\n",
    "    model=model,\n",
    "    clip_values=(min_, max_),\n",
    "    loss=criterion,\n",
    "    optimizer=optimizer,\n",
    "    input_shape=(3, 32, 32),\n",
    "    nb_classes=10,\n",
    "    preprocessing=(mean, std)\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training Epoch 0\n",
      "Accuracy on benign test examples: 11.82%\n",
      "Accuracy on benign trigger test examples: 0.0%\n",
      "Accuracy on poison trigger test examples: 0.0%\n",
      "Success on poison trigger test examples: 0.4%\n",
      "\n",
      "Training Epoch 5\n",
      "Accuracy on benign test examples: 73.19%\n",
      "Accuracy on benign trigger test examples: 44.4%\n",
      "Accuracy on poison trigger test examples: 24.0%\n",
      "Success on poison trigger test examples: 62.2%\n",
      "\n",
      "Training Epoch 10\n",
      "Accuracy on benign test examples: 72.17%\n",
      "Accuracy on benign trigger test examples: 38.3%\n",
      "Accuracy on poison trigger test examples: 15.7%\n",
      "Success on poison trigger test examples: 73.6%\n",
      "\n",
      "Training Epoch 15\n",
      "Accuracy on benign test examples: 72.35000000000001%\n",
      "Accuracy on benign trigger test examples: 39.6%\n",
      "Accuracy on poison trigger test examples: 17.299999999999997%\n",
      "Success on poison trigger test examples: 71.89999999999999%\n",
      "\n"
     ]
    }
   ],
   "source": [
    "trigger_test_inds = np.where(np.all(y_test == source, axis=1))[0]\n",
    "\n",
    "lr_factor = .1\n",
    "lr_schedule = [5, 10, 15]\n",
    "\n",
    "test_poisoned_samples, test_poisoned_labels  = backdoor.poison(x_test[trigger_test_inds], y_test[trigger_test_inds])\n",
    "\n",
    "for i in range(4):\n",
    "    print(\"Training Epoch\", i*5)\n",
    "    predictions = classifier.predict(x_test)\n",
    "    accuracy = np.sum(np.argmax(predictions, axis=1) == np.argmax(y_test, axis=1)) / len(y_test)\n",
    "    print(\"Accuracy on benign test examples: {}%\".format(accuracy * 100))\n",
    "    \n",
    "    predictions = classifier.predict(x_test[trigger_test_inds])\n",
    "    b_accuracy = np.sum(np.argmax(predictions, axis=1) == np.argmax(y_test[trigger_test_inds], axis=1)) / len(trigger_test_inds)\n",
    "    print(\"Accuracy on benign trigger test examples: {}%\".format(b_accuracy * 100))\n",
    "    \n",
    "    predictions = classifier.predict(test_poisoned_samples)\n",
    "    p_accuracy = np.sum(np.argmax(predictions, axis=1) == np.argmax(test_poisoned_labels,axis=1)) / len(test_poisoned_labels)\n",
    "    print(\"Accuracy on poison trigger test examples: {}%\".format(p_accuracy * 100))\n",
    "    p_success = np.sum(np.argmax(predictions, axis=1) == np.argmax(target)) / len(test_poisoned_labels)\n",
    "    print(\"Success on poison trigger test examples: {}%\".format(p_success * 100))\n",
    "    print()\n",
    "    if i != 0:\n",
    "        for param_group in classifier.optimizer.param_groups:\n",
    "            param_group[\"lr\"] *= lr_factor\n",
    "    classifier.fit(poison_x, poison_y, nb_epochs=5, training_mode=False)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Final Performance\n",
      "Accuracy on benign test examples: 72.36%\n",
      "Accuracy on benign trigger test examples: 39.6%\n",
      "Accuracy on poison trigger test examples: 17.299999999999997%\n",
      "Success on poison trigger test examples: 71.8%\n"
     ]
    }
   ],
   "source": [
    "print(\"Final Performance\")\n",
    "predictions = classifier.predict(x_test)\n",
    "accuracy = np.sum(np.argmax(predictions, axis=1) == np.argmax(y_test, axis=1)) / len(y_test)\n",
    "print(\"Accuracy on benign test examples: {}%\".format(accuracy * 100))\n",
    "\n",
    "predictions = classifier.predict(x_test[trigger_test_inds])\n",
    "b_accuracy = np.sum(np.argmax(predictions, axis=1) == np.argmax(y_test[trigger_test_inds], axis=1)) / len(trigger_test_inds)\n",
    "print(\"Accuracy on benign trigger test examples: {}%\".format(b_accuracy * 100))\n",
    "\n",
    "predictions = classifier.predict(test_poisoned_samples)\n",
    "p_accuracy = np.sum(np.argmax(predictions, axis=1) == np.argmax(y_test[trigger_test_inds],axis=1)) / len(trigger_test_inds)\n",
    "print(\"Accuracy on poison trigger test examples: {}%\".format(p_accuracy * 100))\n",
    "p_success = np.sum(np.argmax(predictions, axis=1) == np.argmax(target)) / len(trigger_test_inds)\n",
    "print(\"Success on poison trigger test examples: {}%\".format(p_success * 100))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
